import shlex
from io import StringIO
from typing import Callable

import paramiko
from paramiko.client import SSHClient, AutoAddPolicy
from paramiko.rsakey import RSAKey

DEFAULT_CONNECT_TIMEOUT = 5     # default SSH connection timeout, in seconds
# Legacy misspelled name, kept so existing imports/references keep working.
DEFAULT_CONNENT_TIMEOUT = DEFAULT_CONNECT_TIMEOUT
DEFAULT_NODE_USER = 'root'      # default login user on remote nodes


class SSH:
    """An SSH client used to run commands on a remote node.

    The connection is opened eagerly in ``__init__``. Call :meth:`close`
    (or use the instance as a context manager) to release it.

    Args:
        hostname (str): Host name or IP address of the remote node.

    Keyword Args:
        username (str): User name, default ``DEFAULT_NODE_USER`` ('root').
        port (int): SSH port, default 22.
        timeout (int): Connection timeout in seconds, default
            ``DEFAULT_CONNENT_TIMEOUT`` (5s). ``connect_timeout`` is accepted
            as an alias; ``timeout`` wins if both are given.
        password (str): Password to authenticate with. When omitted or None,
            the RSA private key returned by the getter registered through
            :meth:`set_private_key_getter` is used instead.

    Raises:
        Exception: If no password is given and no private-key getter is set,
            or if connecting/authenticating fails.
    """

    # Cache for a key pair. It is not read or written anywhere in this class;
    # presumably populated elsewhere (verify before relying on it).
    _key_pair = {}
    # Zero-argument callables returning the private key text (parsed with
    # RSAKey.from_private_key) and the public key line (appended to
    # authorized_keys). They are None until set via the set_*_key_getter
    # class methods.
    _private_key_getter: Callable[[], str] = None
    _public_key_getter: Callable[[], str] = None

    def __init__(self, hostname: str, **kwargs) -> None:
        # The docstring historically advertised ``connect_timeout``; honour it
        # as a fallback so such callers are not silently ignored.
        timeout = kwargs.get(
            'timeout', kwargs.get('connect_timeout', DEFAULT_CONNENT_TIMEOUT)
        )
        self.connect_args = {
            'hostname': hostname,
            'username': kwargs.get('username', DEFAULT_NODE_USER),
            'port': kwargs.get('port', 22),
            'timeout': timeout,
        }
        password = kwargs.get('password')
        if password is not None:
            self.connect_args['password'] = password
        else:
            # Look up via type(self) so a getter registered on a subclass
            # (set_private_key_getter stores it on ``cls``) is honoured.
            key_getter = type(self)._private_key_getter
            if key_getter is None:
                raise Exception("_private_key_getter not set")
            self.connect_args['pkey'] = RSAKey.from_private_key(
                StringIO(key_getter())
            )

        self._client: SSHClient = self.client()

    def client(self):
        """Open and return a connected paramiko ``SSHClient``.

        Raises:
            Exception: 'authorization fail, password or pkey error!' on an
                authentication failure, 'authorization fail!' on any other
                connection error. The original error is chained as the cause.
        """
        client = SSHClient()
        try:
            # NOTE(review): AutoAddPolicy trusts unknown host keys without
            # verification (MITM risk); kept for backward compatibility.
            client.set_missing_host_key_policy(AutoAddPolicy())
            client.connect(**self.connect_args)
        except paramiko.AuthenticationException as e:
            client.close()  # don't leak the socket/transport on failure
            raise Exception('authorization fail, password or pkey error!') from e
        except Exception as e:
            client.close()
            raise Exception('authorization fail!') from e
        return client

    def close(self) -> None:
        """Close the underlying SSH connection. Safe to call more than once."""
        client = getattr(self, '_client', None)
        if client is not None:
            client.close()
        self._client = None

    def __enter__(self) -> 'SSH':
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    @classmethod
    def set_private_key_getter(cls, private_key_getter: Callable[[], str]):
        """Register the callable that supplies the private key text."""
        cls._private_key_getter = private_key_getter

    @classmethod
    def set_public_key_getter(cls, public_key_getter: Callable[[], str]):
        """Register the callable that supplies the public key line."""
        cls._public_key_getter = public_key_getter

    def run_command(self, command):
        """Run *command* on the remote node and wait for it to finish.

        stderr is merged into stdout.

        Returns:
            tuple: ``(exit_status, output)``, where *output* is the combined
            stdout/stderr decoded as UTF-8 (undecodable bytes replaced).

        Raises:
            Exception: 'No client!' if the connection is closed or missing.
        """
        if not self._client:
            raise Exception('No client!')
        transport = self._client.get_transport()
        if transport is None:
            raise Exception('No client!')
        session = transport.open_session()
        try:
            session.set_combine_stderr(True)
            session.exec_command(command)
            # Drain the output BEFORE asking for the exit status. Otherwise,
            # output larger than the channel window makes recv_exit_status()
            # hang while the remote process is blocked writing.
            output = session.makefile("rb", -1).read().decode(errors='replace')
            status = session.recv_exit_status()
        finally:
            session.close()
        return status, output

    def add_public_key(self):
        """Append the registered public key to remote ``~/.ssh/authorized_keys``.

        Creates ``~/.ssh`` with mode 700 if missing and sets the file to 600.

        Raises:
            Exception: If no public-key getter is set, the key is empty, or the
                remote command exits non-zero.
        """
        key_getter = type(self)._public_key_getter
        if key_getter is None:
            raise Exception("_public_key_getter not set")
        # Strip surrounding whitespace (e.g. a trailing newline read from a
        # .pub file); repr() would have written it as a literal "\n".
        public_key = key_getter().strip()
        if not public_key:
            raise Exception('public key is empty!')
        # shlex.quote is real shell quoting. repr() switches to double quotes
        # for keys containing "'", which would let "$(...)" expand remotely.
        quoted_key = shlex.quote(public_key)
        command = (
            'mkdir -p -m 700 ~/.ssh && '
            f'echo {quoted_key} >> ~/.ssh/authorized_keys && '
            'chmod 600 ~/.ssh/authorized_keys'
        )
        status, _ = self.run_command(command)
        if status != 0:
            raise Exception('add public key faild!')

    @staticmethod
    def validate_ssh_host(ip: str, password: str, port: int = 22, username: str = 'root'):
        """Connect with *password* and install the registered public key.

        A None *password* falls back to private-key authentication.

        Returns:
            tuple: ``(True, 'authorization success')`` on success, otherwise
            ``(False, 'error: <message>')``. Never raises ``Exception``.
        """
        ssh = None
        try:
            ssh = SSH(hostname=ip, password=password,
                      port=port, username=username, timeout=2)
            ssh.add_public_key()
            return True, 'authorization success'
        except Exception as e:
            return False, f'error: {e}'
        finally:
            # Always release the connection opened above.
            if ssh is not None:
                ssh.close()
